Qualcomm AI Engine Direct - Add MHA2SHA pass #15438
Conversation
@pytorchbot label "release notes: qualcomm"

Hi @cccclai, Thanks

Hi, since it's a really big change and the MHA2SHA pass seems complicated, can you add a test for the pass here: https://github.com/pytorch/executorch/blob/main/backends/qualcomm/tests/test_passes.py? Passes can be fragile, so I'm trying to make sure we have it covered in tests.
Thanks for pointing it out. I have added a test case to check the functionality of MHA2SHA.
```python
        if n.target == exir_ops.edge.aten.convolution.default
    ]
    # Check graph structure: WQ, WK, WV should be converted to SHA
    self.assertTrue(len(conv_nodes) == 25, "Convolution nodes should be split")
```
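A minimal sketch of what such a test might look like in test_passes.py — the pass class `ConvertMhaToSha`, the toy `MhaModule`, and the per-head breakdown of the 25 nodes are assumptions inferred from this PR, not verified names:

```python
import torch
from executorch.exir import to_edge
from executorch.exir.dialects._ops import ops as exir_ops

def test_convert_mha_to_sha(self):
    module = MhaModule()  # hypothetical toy module with one MHA block
    sample_input = (torch.randn(1, 512, 1, 16),)
    edge_prog = to_edge(torch.export.export(module, sample_input))
    graph_module = edge_prog.exported_program().graph_module
    ConvertMhaToSha()(graph_module)  # assumed pass entry point
    conv_nodes = [
        n
        for n in graph_module.graph.nodes
        if n.target == exir_ops.edge.aten.convolution.default
    ]
    # e.g. with 8 heads: WQ/WK/WV each split into 8 per-head convolutions,
    # plus the output projection -> 3 * 8 + 1 = 25 convolution nodes.
    self.assertTrue(len(conv_nodes) == 25, "Convolution nodes should be split")
```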
Thanks for adding the test! Is it possible to check if the numerics are the same?
Sure. I have added it. Thanks!
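For reference, a numeric check along these lines might look as follows (a sketch assuming the pass mutates the graph module in place; the tolerances are illustrative):

```python
import torch

# Run the same input through the graph before and after the MHA2SHA
# rewrite and compare the outputs element-wise.
sample_input = (torch.randn(1, 512, 1, 16),)
ref_output = graph_module(*sample_input)
ConvertMhaToSha()(graph_module)  # assumed pass entry point
sha_output = graph_module(*sample_input)
torch.testing.assert_close(sha_output, ref_output, atol=1e-5, rtol=1e-5)
```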
Summary:
- Integrated the MHA-to-SHA pass and implemented it in qnn_preprocess (see the tensor-level sketch after this list)
- Refactored MHA in static llama
- Added support for masked softmax
- Included SpinQuant R3 support
- Combined the n_heads key-value caches into a single cache per layer to decrease the number of inputs and outputs, which enhances performance
- Deprecated the ShiftPointer KV updater mode
  - Since each layer now has its own KV cache, the v cache no longer benefits from ShiftPointer, which previously avoided copying the new v cache into the input v cache. To prevent user confusion, ShiftPointer mode has been deprecated
- Applied the correct input template for SmolLM2 135M
- Corrected the quantization annotation for reshape
- Removed outdated code from CanonicalizeConv
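The core of the MHA-to-SHA rewrite can be illustrated at the tensor level: a projection convolution covering all heads is equivalent to one smaller convolution per head operating on slices of its output channels. A minimal sketch, assuming a 1x1-conv formulation of the projections and illustrative shapes (the actual pass rewrites the edge graph and its weights, not nn.Module objects):

```python
import torch

n_heads, head_dim, hidden = 8, 64, 512

# One MHA projection expressed as a 1x1 convolution:
# hidden channels in, n_heads * head_dim channels out.
mha_proj = torch.nn.Conv2d(hidden, n_heads * head_dim, kernel_size=1, bias=False)

# Split the weight along the output channels into one SHA conv per head.
sha_projs = []
for h in range(n_heads):
    sha = torch.nn.Conv2d(hidden, head_dim, kernel_size=1, bias=False)
    sha.weight.data = mha_proj.weight.data[h * head_dim : (h + 1) * head_dim]
    sha_projs.append(sha)

# Concatenating the per-head outputs reproduces the MHA projection exactly.
x = torch.randn(1, hidden, 1, 16)
torch.testing.assert_close(
    torch.cat([p(x) for p in sha_projs], dim=1), mha_proj(x)
)
```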
Background
We observed that quantizing and compiling the original SHA model takes a significant amount of time, and switching to the MHA model speeds up this process. We therefore investigated whether converting the MHA model to SHA after quantization is feasible. The conversion cannot be performed during the to_edge transformation, because splitting the convolution weights into SHA would require modifying the state_dict, which is not permitted at that stage. Instead, we apply this pass during qnn_preprocess.
Results
Following the README settings, tested on SM8750 with QNN 2.37. Compared the new pass `convert_mha_to_sha` with the original SHA structure.